"""
graph_utils.symmetrize
----------------------

Functions to symmetrize a directed edge list (rows, cols, vals) for graph construction.

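Supported rules (a = weight of i->j, b = weight of j->i, 0 if absent):
- 'mean': average of the weights over the directions that are present
- 'max', 'min': elementwise maximum / minimum of a and b
- 'umap': a + b - a*b
- 'geom': a*b
- 'harm': 2*a*b / (a + b)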
"""

import numpy as np
from scipy.sparse import coo_matrix, csr_matrix

def symmetrize_graph(
    rows: np.ndarray,
    cols: np.ndarray,
    vals: np.ndarray,
    n_samples: int,
    symmetrize: str = "max"
):
    """
    Symmetrize a directed kNN graph according to the specified rule.

    Writing ``a`` for the weight of edge ``i -> j`` and ``b`` for the weight
    of ``j -> i`` (0 when the edge is absent), the rules are:

    - ``'mean'``: sum of the weights present between ``i`` and ``j`` divided
      by the number of such entries, so a one-directional edge keeps its
      weight and a mutual edge gets ``(a + b) / 2``.
    - ``'max'`` / ``'min'``: elementwise maximum / minimum of ``a`` and ``b``.
    - ``'umap'``: ``a + b - a * b``.
    - ``'harm'``: ``2 * a * b / (a + b + eps)`` with ``eps = 1e-8``; non-zero
      only for mutual edges.
    - ``'geom'``: ``a * b`` (plain product; no square root is taken).

    Parameters
    ----------
    rows : np.ndarray
        Row indices of directed edges.
    cols : np.ndarray
        Col indices of directed edges.
    vals : np.ndarray
        Edge weights.
    n_samples : int
        Number of samples (graph size).
    symmetrize : str
        Symmetrization rule: 'mean', 'max', 'min', 'umap', 'geom', 'harm'.

    Returns
    -------
    sym_rows : np.ndarray
    sym_cols : np.ndarray
    sym_vals : np.ndarray
        Symmetrized edge list in COO format (int32 indices, float32 values).

    Raises
    ------
    ValueError
        If ``symmetrize`` is not one of the supported rules.
    """
    # Reject unknown modes before doing any sparse-matrix work.
    if symmetrize not in ("mean", "max", "min", "umap", "harm", "geom"):
        raise ValueError(f"Unknown symmetrize mode: {symmetrize!r}")

    shape = (n_samples, n_samples)

    if symmetrize == "mean":
        # Insert every edge in both directions.
        I = np.concatenate([rows, cols])
        J = np.concatenate([cols, rows])
        V = np.concatenate([vals, vals])

        # Numerator: summed weight per (i, j) cell.
        S_num = coo_matrix((V, (I, J)), shape=shape).tocsr()
        S_num.sum_duplicates()

        # Denominator: number of entries that landed in each (i, j) cell.
        V_ind = np.ones_like(V, dtype=np.float32)
        S_den = coo_matrix((V_ind, (I, J)), shape=shape).tocsr()
        S_den.sum_duplicates()

        # Both matrices come from the same (I, J) pairs, so their stored
        # entries line up and the division can be done directly on `.data`.
        # Every stored count is >= 1, so no zero-division guard is needed.
        S = S_num.copy()
        S.data = S_num.data / S_den.data
    else:
        A = coo_matrix((vals, (rows, cols)), shape=shape).tocsr()
        if symmetrize == "max":
            S = A.maximum(A.T)
        elif symmetrize == "min":
            S = A.minimum(A.T)
        elif symmetrize == "umap":
            S = A + A.T - A.multiply(A.T)
        elif symmetrize == "harm":
            eps = 1e-8
            prod = A.multiply(A.T)
            denom = (A + A.T).tocsr()
            # SciPy sparse matrices support neither `sparse + nonzero scalar`
            # nor `scalar / sparse`, so take the reciprocal on the stored
            # entries only and let the sparse-sparse multiply align them
            # with `prod`.
            inv_denom = csr_matrix(
                (1.0 / (denom.data + eps), denom.indices, denom.indptr),
                shape=denom.shape,
            )
            S = prod.multiply(2.0).multiply(inv_denom)
        else:  # "geom"
            # NOTE(review): this is the plain product a*b, not sqrt(a*b);
            # confirm that is the intended "geometric" rule.
            S = A.multiply(A.T)

    S_coo = S.tocoo(copy=False)
    return (
        S_coo.row.astype(np.int32),
        S_coo.col.astype(np.int32),
        S_coo.data.astype(np.float32),
    )

# Public API: the only name exported by a star-import of this module.
__all__ = ["symmetrize_graph"]